{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "uuid": "b17eeed9-0c38-47ae-96a5-25b1f9baf9a4"
   },
   "outputs": [],
   "source": [
    "#! -*- coding:utf-8 -*-\n",
    "# 句子对分类任务，LCQMC数据集\n",
    "# val_acc: 0.887071, test_acc: 0.870320\n",
    "\n",
    "import numpy as np\n",
    "from bert4keras.backend import keras, set_gelu, K\n",
    "from bert4keras.tokenizers import Tokenizer\n",
    "from bert4keras.models import build_transformer_model\n",
    "from bert4keras.optimizers import Adam\n",
    "from bert4keras.snippets import sequence_padding, DataGenerator\n",
    "from bert4keras.snippets import open\n",
    "from keras.layers import Dropout, Dense\n",
    "\n",
    "# Training hyper-parameters.\n",
    "maxlen = 128\n",
    "batch_size = 64\n",
    "\n",
    "# Pretrained BERT files (relative paths, resolved against the working directory).\n",
    "config_path = 'bert_config.json'\n",
    "checkpoint_path = 'bert_model.ckpt'\n",
    "dict_path = 'vocab.txt'\n",
    "\n",
    "# Switch the GELU implementation used by bert4keras to its 'tanh' variant.\n",
    "set_gelu('tanh')\n",
    "\n",
    "\n",
    "def load_data(filename):\n",
    "    D = []\n",
    "    with open(filename, encoding='utf-8') as f:\n",
    "        for l in f:\n",
    "            text1, text2, label = l.strip().split('\\t')\n",
    "            D.append((text1, text2, int(label)))\n",
    "    return D\n",
    "\n",
    "\n",
    "# Load the LCQMC train / validation / test splits.\n",
    "train_data = load_data('lcqmc.train.data')\n",
    "valid_data = load_data('lcqmc.valid.data')\n",
    "test_data = load_data('lcqmc.test.data')\n",
    "# Fail fast on an empty split instead of erroring later during training/evaluation.\n",
    "assert train_data and valid_data and test_data, 'an LCQMC split is empty'\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "uuid": "5a177493-cb23-415b-bc39-d855c7c7b9a4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model_2\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "Input-Token (InputLayer)        (None, None)         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "Input-Segment (InputLayer)      (None, None)         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "Embedding-Token (Embedding)     (None, None, 768)    16226304    Input-Token[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "Embedding-Segment (Embedding)   (None, None, 768)    1536        Input-Segment[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "Embedding-Token-Segment (Add)   (None, None, 768)    0           Embedding-Token[0][0]            \n",
      "                                                                 Embedding-Segment[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "Embedding-Position (PositionEmb (None, None, 768)    393216      Embedding-Token-Segment[0][0]    \n",
      "__________________________________________________________________________________________________\n",
      "Embedding-Norm (LayerNormalizat (None, None, 768)    1536        Embedding-Position[0][0]         \n",
      "__________________________________________________________________________________________________\n",
      "Embedding-Dropout (Dropout)     (None, None, 768)    0           Embedding-Norm[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-0-MultiHeadSelfAtte (None, None, 768)    2362368     Embedding-Dropout[0][0]          \n",
      "                                                                 Embedding-Dropout[0][0]          \n",
      "                                                                 Embedding-Dropout[0][0]          \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-0-MultiHeadSelfAtte (None, None, 768)    0           Transformer-0-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-0-MultiHeadSelfAtte (None, None, 768)    0           Embedding-Dropout[0][0]          \n",
      "                                                                 Transformer-0-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-0-MultiHeadSelfAtte (None, None, 768)    1536        Transformer-0-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-0-FeedForward (Feed (None, None, 768)    4722432     Transformer-0-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-0-FeedForward-Dropo (None, None, 768)    0           Transformer-0-FeedForward[0][0]  \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-0-FeedForward-Add ( (None, None, 768)    0           Transformer-0-MultiHeadSelfAttent\n",
      "                                                                 Transformer-0-FeedForward-Dropout\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-0-FeedForward-Norm  (None, None, 768)    1536        Transformer-0-FeedForward-Add[0][\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-1-MultiHeadSelfAtte (None, None, 768)    2362368     Transformer-0-FeedForward-Norm[0]\n",
      "                                                                 Transformer-0-FeedForward-Norm[0]\n",
      "                                                                 Transformer-0-FeedForward-Norm[0]\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-1-MultiHeadSelfAtte (None, None, 768)    0           Transformer-1-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-1-MultiHeadSelfAtte (None, None, 768)    0           Transformer-0-FeedForward-Norm[0]\n",
      "                                                                 Transformer-1-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-1-MultiHeadSelfAtte (None, None, 768)    1536        Transformer-1-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-1-FeedForward (Feed (None, None, 768)    4722432     Transformer-1-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-1-FeedForward-Dropo (None, None, 768)    0           Transformer-1-FeedForward[0][0]  \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-1-FeedForward-Add ( (None, None, 768)    0           Transformer-1-MultiHeadSelfAttent\n",
      "                                                                 Transformer-1-FeedForward-Dropout\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-1-FeedForward-Norm  (None, None, 768)    1536        Transformer-1-FeedForward-Add[0][\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-2-MultiHeadSelfAtte (None, None, 768)    2362368     Transformer-1-FeedForward-Norm[0]\n",
      "                                                                 Transformer-1-FeedForward-Norm[0]\n",
      "                                                                 Transformer-1-FeedForward-Norm[0]\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-2-MultiHeadSelfAtte (None, None, 768)    0           Transformer-2-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-2-MultiHeadSelfAtte (None, None, 768)    0           Transformer-1-FeedForward-Norm[0]\n",
      "                                                                 Transformer-2-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-2-MultiHeadSelfAtte (None, None, 768)    1536        Transformer-2-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-2-FeedForward (Feed (None, None, 768)    4722432     Transformer-2-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-2-FeedForward-Dropo (None, None, 768)    0           Transformer-2-FeedForward[0][0]  \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-2-FeedForward-Add ( (None, None, 768)    0           Transformer-2-MultiHeadSelfAttent\n",
      "                                                                 Transformer-2-FeedForward-Dropout\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-2-FeedForward-Norm  (None, None, 768)    1536        Transformer-2-FeedForward-Add[0][\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-3-MultiHeadSelfAtte (None, None, 768)    2362368     Transformer-2-FeedForward-Norm[0]\n",
      "                                                                 Transformer-2-FeedForward-Norm[0]\n",
      "                                                                 Transformer-2-FeedForward-Norm[0]\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-3-MultiHeadSelfAtte (None, None, 768)    0           Transformer-3-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-3-MultiHeadSelfAtte (None, None, 768)    0           Transformer-2-FeedForward-Norm[0]\n",
      "                                                                 Transformer-3-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-3-MultiHeadSelfAtte (None, None, 768)    1536        Transformer-3-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-3-FeedForward (Feed (None, None, 768)    4722432     Transformer-3-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-3-FeedForward-Dropo (None, None, 768)    0           Transformer-3-FeedForward[0][0]  \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-3-FeedForward-Add ( (None, None, 768)    0           Transformer-3-MultiHeadSelfAttent\n",
      "                                                                 Transformer-3-FeedForward-Dropout\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-3-FeedForward-Norm  (None, None, 768)    1536        Transformer-3-FeedForward-Add[0][\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-4-MultiHeadSelfAtte (None, None, 768)    2362368     Transformer-3-FeedForward-Norm[0]\n",
      "                                                                 Transformer-3-FeedForward-Norm[0]\n",
      "                                                                 Transformer-3-FeedForward-Norm[0]\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-4-MultiHeadSelfAtte (None, None, 768)    0           Transformer-4-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-4-MultiHeadSelfAtte (None, None, 768)    0           Transformer-3-FeedForward-Norm[0]\n",
      "                                                                 Transformer-4-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-4-MultiHeadSelfAtte (None, None, 768)    1536        Transformer-4-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-4-FeedForward (Feed (None, None, 768)    4722432     Transformer-4-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-4-FeedForward-Dropo (None, None, 768)    0           Transformer-4-FeedForward[0][0]  \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-4-FeedForward-Add ( (None, None, 768)    0           Transformer-4-MultiHeadSelfAttent\n",
      "                                                                 Transformer-4-FeedForward-Dropout\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-4-FeedForward-Norm  (None, None, 768)    1536        Transformer-4-FeedForward-Add[0][\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-5-MultiHeadSelfAtte (None, None, 768)    2362368     Transformer-4-FeedForward-Norm[0]\n",
      "                                                                 Transformer-4-FeedForward-Norm[0]\n",
      "                                                                 Transformer-4-FeedForward-Norm[0]\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-5-MultiHeadSelfAtte (None, None, 768)    0           Transformer-5-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-5-MultiHeadSelfAtte (None, None, 768)    0           Transformer-4-FeedForward-Norm[0]\n",
      "                                                                 Transformer-5-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-5-MultiHeadSelfAtte (None, None, 768)    1536        Transformer-5-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-5-FeedForward (Feed (None, None, 768)    4722432     Transformer-5-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-5-FeedForward-Dropo (None, None, 768)    0           Transformer-5-FeedForward[0][0]  \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-5-FeedForward-Add ( (None, None, 768)    0           Transformer-5-MultiHeadSelfAttent\n",
      "                                                                 Transformer-5-FeedForward-Dropout\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-5-FeedForward-Norm  (None, None, 768)    1536        Transformer-5-FeedForward-Add[0][\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-6-MultiHeadSelfAtte (None, None, 768)    2362368     Transformer-5-FeedForward-Norm[0]\n",
      "                                                                 Transformer-5-FeedForward-Norm[0]\n",
      "                                                                 Transformer-5-FeedForward-Norm[0]\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-6-MultiHeadSelfAtte (None, None, 768)    0           Transformer-6-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-6-MultiHeadSelfAtte (None, None, 768)    0           Transformer-5-FeedForward-Norm[0]\n",
      "                                                                 Transformer-6-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-6-MultiHeadSelfAtte (None, None, 768)    1536        Transformer-6-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-6-FeedForward (Feed (None, None, 768)    4722432     Transformer-6-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-6-FeedForward-Dropo (None, None, 768)    0           Transformer-6-FeedForward[0][0]  \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-6-FeedForward-Add ( (None, None, 768)    0           Transformer-6-MultiHeadSelfAttent\n",
      "                                                                 Transformer-6-FeedForward-Dropout\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-6-FeedForward-Norm  (None, None, 768)    1536        Transformer-6-FeedForward-Add[0][\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-7-MultiHeadSelfAtte (None, None, 768)    2362368     Transformer-6-FeedForward-Norm[0]\n",
      "                                                                 Transformer-6-FeedForward-Norm[0]\n",
      "                                                                 Transformer-6-FeedForward-Norm[0]\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-7-MultiHeadSelfAtte (None, None, 768)    0           Transformer-7-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-7-MultiHeadSelfAtte (None, None, 768)    0           Transformer-6-FeedForward-Norm[0]\n",
      "                                                                 Transformer-7-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-7-MultiHeadSelfAtte (None, None, 768)    1536        Transformer-7-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-7-FeedForward (Feed (None, None, 768)    4722432     Transformer-7-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-7-FeedForward-Dropo (None, None, 768)    0           Transformer-7-FeedForward[0][0]  \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-7-FeedForward-Add ( (None, None, 768)    0           Transformer-7-MultiHeadSelfAttent\n",
      "                                                                 Transformer-7-FeedForward-Dropout\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-7-FeedForward-Norm  (None, None, 768)    1536        Transformer-7-FeedForward-Add[0][\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-8-MultiHeadSelfAtte (None, None, 768)    2362368     Transformer-7-FeedForward-Norm[0]\n",
      "                                                                 Transformer-7-FeedForward-Norm[0]\n",
      "                                                                 Transformer-7-FeedForward-Norm[0]\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-8-MultiHeadSelfAtte (None, None, 768)    0           Transformer-8-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-8-MultiHeadSelfAtte (None, None, 768)    0           Transformer-7-FeedForward-Norm[0]\n",
      "                                                                 Transformer-8-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-8-MultiHeadSelfAtte (None, None, 768)    1536        Transformer-8-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-8-FeedForward (Feed (None, None, 768)    4722432     Transformer-8-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-8-FeedForward-Dropo (None, None, 768)    0           Transformer-8-FeedForward[0][0]  \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-8-FeedForward-Add ( (None, None, 768)    0           Transformer-8-MultiHeadSelfAttent\n",
      "                                                                 Transformer-8-FeedForward-Dropout\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-8-FeedForward-Norm  (None, None, 768)    1536        Transformer-8-FeedForward-Add[0][\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-9-MultiHeadSelfAtte (None, None, 768)    2362368     Transformer-8-FeedForward-Norm[0]\n",
      "                                                                 Transformer-8-FeedForward-Norm[0]\n",
      "                                                                 Transformer-8-FeedForward-Norm[0]\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-9-MultiHeadSelfAtte (None, None, 768)    0           Transformer-9-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-9-MultiHeadSelfAtte (None, None, 768)    0           Transformer-8-FeedForward-Norm[0]\n",
      "                                                                 Transformer-9-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-9-MultiHeadSelfAtte (None, None, 768)    1536        Transformer-9-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-9-FeedForward (Feed (None, None, 768)    4722432     Transformer-9-MultiHeadSelfAttent\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-9-FeedForward-Dropo (None, None, 768)    0           Transformer-9-FeedForward[0][0]  \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-9-FeedForward-Add ( (None, None, 768)    0           Transformer-9-MultiHeadSelfAttent\n",
      "                                                                 Transformer-9-FeedForward-Dropout\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-9-FeedForward-Norm  (None, None, 768)    1536        Transformer-9-FeedForward-Add[0][\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-10-MultiHeadSelfAtt (None, None, 768)    2362368     Transformer-9-FeedForward-Norm[0]\n",
      "                                                                 Transformer-9-FeedForward-Norm[0]\n",
      "                                                                 Transformer-9-FeedForward-Norm[0]\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-10-MultiHeadSelfAtt (None, None, 768)    0           Transformer-10-MultiHeadSelfAtten\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-10-MultiHeadSelfAtt (None, None, 768)    0           Transformer-9-FeedForward-Norm[0]\n",
      "                                                                 Transformer-10-MultiHeadSelfAtten\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-10-MultiHeadSelfAtt (None, None, 768)    1536        Transformer-10-MultiHeadSelfAtten\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-10-FeedForward (Fee (None, None, 768)    4722432     Transformer-10-MultiHeadSelfAtten\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-10-FeedForward-Drop (None, None, 768)    0           Transformer-10-FeedForward[0][0] \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-10-FeedForward-Add  (None, None, 768)    0           Transformer-10-MultiHeadSelfAtten\n",
      "                                                                 Transformer-10-FeedForward-Dropou\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-10-FeedForward-Norm (None, None, 768)    1536        Transformer-10-FeedForward-Add[0]\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-11-MultiHeadSelfAtt (None, None, 768)    2362368     Transformer-10-FeedForward-Norm[0\n",
      "                                                                 Transformer-10-FeedForward-Norm[0\n",
      "                                                                 Transformer-10-FeedForward-Norm[0\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-11-MultiHeadSelfAtt (None, None, 768)    0           Transformer-11-MultiHeadSelfAtten\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-11-MultiHeadSelfAtt (None, None, 768)    0           Transformer-10-FeedForward-Norm[0\n",
      "                                                                 Transformer-11-MultiHeadSelfAtten\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-11-MultiHeadSelfAtt (None, None, 768)    1536        Transformer-11-MultiHeadSelfAtten\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-11-FeedForward (Fee (None, None, 768)    4722432     Transformer-11-MultiHeadSelfAtten\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-11-FeedForward-Drop (None, None, 768)    0           Transformer-11-FeedForward[0][0] \n",
      "__________________________________________________________________________________________________\n",
      "Transformer-11-FeedForward-Add  (None, None, 768)    0           Transformer-11-MultiHeadSelfAtten\n",
      "                                                                 Transformer-11-FeedForward-Dropou\n",
      "__________________________________________________________________________________________________\n",
      "Transformer-11-FeedForward-Norm (None, None, 768)    1536        Transformer-11-FeedForward-Add[0]\n",
      "__________________________________________________________________________________________________\n",
      "Pooler (Lambda)                 (None, 768)          0           Transformer-11-FeedForward-Norm[0\n",
      "__________________________________________________________________________________________________\n",
      "Pooler-Dense (Dense)            (None, 768)          590592      Pooler[0][0]                     \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 768)          0           Pooler-Dense[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "dense_73 (Dense)                (None, 2)            1538        dropout_1[0][0]                  \n",
      "==================================================================================================\n",
      "Total params: 102,269,186\n",
      "Trainable params: 102,269,186\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/admin/.local/lib/python3.6/site-packages/tensorflow/python/framework/indexed_slices.py:434: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "1/1 [==============================] - 12s 12s/step - loss: 0.6614 - accuracy: 0.4000\n",
      "val_acc: 0.50000, best_val_acc: 0.50000, test_acc: 0.60000\n",
      "\n",
      "Epoch 2/20\n",
      "1/1 [==============================] - 1s 979ms/step - loss: 0.5844 - accuracy: 0.7000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 3/20\n",
      "1/1 [==============================] - 1s 935ms/step - loss: 0.5688 - accuracy: 0.8000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 4/20\n",
      "1/1 [==============================] - 1s 1s/step - loss: 0.4999 - accuracy: 0.9000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 5/20\n",
      "1/1 [==============================] - 1s 1s/step - loss: 0.4265 - accuracy: 1.0000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 6/20\n",
      "1/1 [==============================] - 1s 987ms/step - loss: 0.4336 - accuracy: 0.7000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 7/20\n",
      "1/1 [==============================] - 1s 944ms/step - loss: 0.3797 - accuracy: 0.8000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 8/20\n",
      "1/1 [==============================] - 1s 947ms/step - loss: 0.3372 - accuracy: 0.8000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 9/20\n",
      "1/1 [==============================] - 1s 927ms/step - loss: 0.3232 - accuracy: 1.0000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 10/20\n",
      "1/1 [==============================] - 1s 897ms/step - loss: 0.2520 - accuracy: 1.0000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 11/20\n",
      "1/1 [==============================] - 1s 879ms/step - loss: 0.2555 - accuracy: 1.0000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 12/20\n",
      "1/1 [==============================] - 1s 893ms/step - loss: 0.2366 - accuracy: 1.0000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 13/20\n",
      "1/1 [==============================] - 1s 903ms/step - loss: 0.2111 - accuracy: 1.0000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 14/20\n",
      "1/1 [==============================] - 1s 894ms/step - loss: 0.2102 - accuracy: 1.0000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 15/20\n",
      "1/1 [==============================] - 1s 907ms/step - loss: 0.2112 - accuracy: 1.0000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 16/20\n",
      "1/1 [==============================] - 1s 913ms/step - loss: 0.1707 - accuracy: 1.0000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 17/20\n",
      "1/1 [==============================] - 1s 871ms/step - loss: 0.1551 - accuracy: 1.0000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 18/20\n",
      "1/1 [==============================] - 1s 915ms/step - loss: 0.1302 - accuracy: 1.0000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 19/20\n",
      "1/1 [==============================] - 1s 879ms/step - loss: 0.1161 - accuracy: 1.0000\n",
      "val_acc: 0.60000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "Epoch 20/20\n",
      "1/1 [==============================] - 1s 878ms/step - loss: 0.1044 - accuracy: 1.0000\n",
      "val_acc: 0.50000, best_val_acc: 0.60000, test_acc: 0.50000\n",
      "\n",
      "final test acc: 0.500000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# 建立分词器\n",
    "tokenizer = Tokenizer(dict_path, do_lower_case=True)\n",
    "\n",
    "\n",
    "class data_generator(DataGenerator):\n",
    "    \"\"\"数据生成器\n",
    "    \"\"\"\n",
    "    def __iter__(self, random=False):\n",
    "        batch_token_ids, batch_segment_ids, batch_labels = [], [], []\n",
    "        for is_end, (text1, text2, label) in self.sample(random):\n",
    "            token_ids, segment_ids = tokenizer.encode(\n",
    "                text1, text2, maxlen=maxlen\n",
    "            )\n",
    "            batch_token_ids.append(token_ids)\n",
    "            batch_segment_ids.append(segment_ids)\n",
    "            batch_labels.append([label])\n",
    "            if len(batch_token_ids) == self.batch_size or is_end:\n",
    "                batch_token_ids = sequence_padding(batch_token_ids)\n",
    "                batch_segment_ids = sequence_padding(batch_segment_ids)\n",
    "                batch_labels = sequence_padding(batch_labels)\n",
    "                yield [batch_token_ids, batch_segment_ids], batch_labels\n",
    "                batch_token_ids, batch_segment_ids, batch_labels = [], [], []\n",
    "\n",
    "\n",
    "# 加载预训练模型\n",
    "bert = build_transformer_model(\n",
    "    config_path=config_path,\n",
    "    checkpoint_path=checkpoint_path,\n",
    "    with_pool=True,\n",
    "    return_keras_model=False,\n",
    ")\n",
    "\n",
    "output = Dropout(rate=0.1)(bert.model.output)\n",
    "output = Dense(\n",
    "    units=2, activation='softmax', kernel_initializer=bert.initializer\n",
    ")(output)\n",
    "\n",
    "model = keras.models.Model(bert.model.input, output)\n",
    "model.summary()\n",
    "\n",
    "model.compile(\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    optimizer=Adam(2e-5),  # 用足够小的学习率\n",
    "    # optimizer=PiecewiseLinearLearningRate(Adam(5e-5), {10000: 1, 30000: 0.1}),\n",
    "    metrics=['accuracy'],\n",
    ")\n",
    "\n",
    "# 转换数据集\n",
    "train_generator = data_generator(train_data, batch_size)\n",
    "valid_generator = data_generator(valid_data, batch_size)\n",
    "test_generator = data_generator(test_data, batch_size)\n",
    "\n",
    "\n",
    "def evaluate(data):\n",
    "    total, right = 0., 0.\n",
    "    for x_true, y_true in data:\n",
    "        y_pred = model.predict(x_true).argmax(axis=1)\n",
    "        y_true = y_true[:, 0]\n",
    "        total += len(y_true)\n",
    "        right += (y_true == y_pred).sum()\n",
    "    return right / total\n",
    "\n",
    "\n",
    "class Evaluator(keras.callbacks.Callback):\n",
    "    \"\"\"评估与保存\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        self.best_val_acc = 0.\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        val_acc = evaluate(valid_generator)\n",
    "        if val_acc > self.best_val_acc:\n",
    "            self.best_val_acc = val_acc\n",
    "            model.save_weights('best_model.weights')\n",
    "        test_acc = evaluate(test_generator)\n",
    "        print(\n",
    "            u'val_acc: %.5f, best_val_acc: %.5f, test_acc: %.5f\\n' %\n",
    "            (val_acc, self.best_val_acc, test_acc)\n",
    "        )\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "\n",
    "    evaluator = Evaluator()\n",
    "\n",
    "    model.fit_generator(\n",
    "        train_generator.forfit(),\n",
    "        steps_per_epoch=len(train_generator),\n",
    "        epochs=20,\n",
    "        callbacks=[evaluator]\n",
    "    )\n",
    "\n",
    "    model.load_weights('best_model.weights')\n",
    "    print(u'final test acc: %05f\\n' % (evaluate(test_generator)))\n",
    "\n",
    "else:\n",
    "\n",
    "    model.load_weights('best_model.weights')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "uuid": "28b78287-70df-457e-a7e9-3b880c60964a"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "uuid": "1c1fd79b-0b4d-411a-b39d-264baf0911e0"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "uuid": "3b33742a-1748-41f5-b163-87ef06e8e758"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "uuid": "6ba43baf-7463-459d-b172-d1d27b1b7bbd"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
